import os
import sys
import logging
import signal
import random
from argparse import ArgumentParser

import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from torchsummary import summary

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint, EarlyStopping
from ignite.utils import convert_tensor

from utils import Experiment
from utils.factory import *

logging.basicConfig(level=logging.INFO, format='')
logger = logging.getLogger()

torch.backends.cudnn.benchmark = True

def main(config):
    assert validate_config(config), "ERROR: Config file is invalid. Please see log for details."

    logger.info("INFO: {}".format(config.toDict()))

    # Seed the random number generators; since we use torch DataLoaders, seeding torch keeps shuffling reproducible.
    # Remember to seed any custom datasets etc. with the same seed.
    if config.seed > 0:
        torch.backends.cudnn.deterministic = True
        torch.cuda.manual_seed_all(config.seed)
        torch.manual_seed(config.seed)
        random.seed(config.seed)
        np.random.seed(config.seed)

    if config.device == "cpu" and torch.cuda.is_available():
        logger.warning("WARNING: Not using the GPU")
    elif config.device == "cuda":
        config.device = f"cuda:{config.device_ids[0]}"

    config.nsave = config.nsave if "nsave" in config else 5

    logger.info("INFO: Creating datasets and dataloaders...")
    # Create the training dataset
    dset_train = create_dataset(config.datasets.train)

    # If the validation config has a parameter called `split`, we ask the training dataset for the validation split.
    # Note that in this case the training dataset should not shuffle its data in __init__, because only in
    # get_validation_split do we know how to split the data (unless the shuffling is deterministic).
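
    # For illustration only, a hypothetical "datasets" section might look like the sketch below; the dataset
    # fields ("name", "path", "batch_size") are assumptions and depend on what create_dataset,
    # get_data_loader and get_validation_split expect in this project:
    #
    # "datasets": {
    #     "train":      {"name": "MyDataset", "path": "data/train", "batch_size": 32},
    #     "validation": {"path": "data/val"}  # inherits everything else from "train" (see below)
    # }
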
    train_ids = None
    if 'validation' in config.datasets:
        # Ensure we have a full config for validation; this means we don't need to specify everything in the
        # config file, only the differences from the training config.
        config_val = config.datasets.train.copy()
        config_val.update(config.datasets.validation)

        dset_val = create_dataset(config_val)

        loader_val = get_data_loader(dset_val, config_val)
        print("Using validation dataset of {} samples or {} batches".format(len(dset_val), len(loader_val)))
    elif 'includes_validation' in config.datasets.train:
        train_ids, val_ids = dset_train.get_validation_split(config.datasets.train)
        loader_val = get_data_loader(dset_train, config.datasets.train, val_ids)
        print("Using validation dataset of {} samples or {} batches".format(len(val_ids), len(loader_val)))
    else:
        logger.warning("WARNING: No validation dataset was specified")
        dset_val = None
        loader_val = None

    loader_train = get_data_loader(dset_train, config.datasets.train, train_ids)
    dset_len = len(train_ids) if train_ids is not None else len(dset_train)
    print("Using training dataset of {} samples or {} batches".format(dset_len, len(loader_train)))

    cp_paths = {}
    last_epoch = 0
    if 'checkpoint' in config:
        checkpoint_dir = config.checkpoint_dir if 'checkpoint_dir' in config else config.result_path
        cp_paths, last_epoch = config.get_checkpoints(path=checkpoint_dir, tag=config.checkpoint)
        print("Found checkpoint {} for Epoch {}".format(config.checkpoint, last_epoch))
        last_epoch = last_epoch if config.resume_from == -1 else config.resume_from
        # config.epochs = config.epochs - last_epoch if last_epoch else config.epochs

    # Build the models
    models = {}  # dict holding each network / model object by name
    for name, model in config.model.items():
        logger.info("INFO: Building the {} model".format(name))
        models[name] = build_model(model)

        # Load the checkpoint
        if name in cp_paths:
            models[name].load_state_dict( torch.load( cp_paths[name] ) )
            logger.info("INFO: Loaded model {} checkpoint {}".format(name, cp_paths[name]))

        # Multi-GPU: wrap in DataParallel when more than one device id is given
        if len(config.device_ids) > 1:
            models[name] = nn.DataParallel(models[name], device_ids=config.device_ids)

        models[name].to(config.device)
        print(models[name])

        if 'debug' in config and config.debug is True:
            print("*********** {} ************".format(name))
            # Use a different variable name so the outer loop's `name` is not shadowed
            for param_name, param in models[name].named_parameters():
                if param.requires_grad:
                    print(param_name, param.data)

    # Optimizers
    optimizers = {}
    for name, conf in config.optimizer.items():
        optim_conf = conf.copy()
        del optim_conf["models"]

        model_params = []
        for model_id in conf.models:
            model_params.extend( list(filter(lambda p: p.requires_grad, models[model_id].parameters())) )
        
        logger.info("INFO: Using {} Optimization for {}".format(list(optim_conf.keys())[0], name))
        optimizers[name] = get_optimizer(model_params, optim_conf)

        # Restoring the optimizer can break because not all parameters are included in the optimizer state,
        # so unless we are continuing training in the same run directory, just create a fresh optimizer.
        if name in cp_paths and 'checkpoint_dir' not in config:
            optimizers[name].load_state_dict( torch.load( cp_paths[name] ) )
            logger.info("INFO: Loaded {} optimizer checkpoint {}".format(name, cp_paths[name]))

            for state in optimizers[name].state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.to(config.device)
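
    # For illustration only: judging from how this loop reads the config, a hypothetical "optimizer" section
    # could look like the sketch below, where "models" lists the model names to optimise and the remaining
    # key is the optimizer name with its arguments (the exact format accepted by get_optimizer is an assumption):
    #
    # "optimizer": {
    #     "main": {
    #         "models": ["net"],
    #         "Adam": {"lr": 1e-4, "weight_decay": 1e-5}
    #     }
    # }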

    # Losses
    losses = {}
    for name, fcns in config.loss.items():
        losses[name] = []
        for l in fcns:
            losses[name].append( get_loss(l) )
            assert losses[name][-1], "Loss function {} for {} could not be found, please check your config".format(l, name)
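
    # For illustration only: a hypothetical "loss" section, where each entry maps a name to a list of loss
    # specifications understood by get_loss (the loss names below are assumptions):
    #
    # "loss": {
    #     "main": ["MSELoss", "L1Loss"]
    # }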

    # Experiment logger (e.g. tensorboard)
    exp_logger = None
    if 'logger' in config:
        logger.info("INFO: Initialising the experiment logger")
        exp_logger = get_experiment_logger(config.result_path, config.logger)
        if last_epoch > 0:  # fast-forward the logger to the last completed epoch
            exp_logger.fast_forward(last_epoch, len(loader_train))

    # Trainer
    logger.info("INFO: Creating training manager and configuring callbacks")
    trainer = get_trainer(models, optimizers, losses, exp_logger, config)

    trainer_engine = Engine(trainer.train)
    evaluator_engine = Engine(trainer.evaluate)

    trainer.attach("train_loader", loader_train)
    trainer.attach("validation_loader", loader_val)
    trainer.attach("evaluation_engine", evaluator_engine)
    trainer.attach("train_engine", trainer_engine)

    # Attach metrics to the engines (skipped if the config has no "metrics" section)
    if 'metrics' in config:
        for phase in config.metrics.keys():
            if phase == "train":
                engine = trainer_engine
            elif phase == "validation":
                engine = evaluator_engine
            else:
                logger.warning("WARNING: Unknown metrics phase {}, skipping".format(phase))
                continue

            for name, metric in config.metrics[phase].items():
                metric = get_metric(metric)
                if metric is not None:
                    metric.attach(engine, name)
                else:
                    logger.warning("WARNING: Metric {} could not be created for {} phase".format(name, phase))
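
    # For illustration only: a hypothetical "metrics" section, keyed by phase and then by metric name; the
    # value format is whatever get_metric expects in this project (an assumption here):
    #
    # "metrics": {
    #     "validation": {"val_loss": "<metric spec understood by get_metric>"}
    # }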

    # Register default callbacks to run the validation stage
    if loader_val is not None:
        # If an epoch is long (more than 2000 batches), run validation four times per epoch instead of once
        if len(loader_train) > 2000:
            num_batch = len(loader_train)//4

            def validate_run(engine):
                if engine.state.iteration % num_batch == 0:
                    evaluator_engine.run(loader_val)

            trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, validate_run)
        else:
            trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, lambda engine: evaluator_engine.run(loader_val))

    # Initialise the epoch from the checkpoint - this is a workaround, as Ignite does not expose this directly
    # and will not complain if it is set incorrectly
    if last_epoch > 0:
        def set_epoch(engine, last_epoch):
            engine.state.epoch = last_epoch

        trainer_engine.add_event_handler(Events.STARTED, set_epoch, last_epoch)

    # Learning rate scheduling (optional; skipped when the config has no "scheduler" section)
    '''
    The "scheduler" section of the config is written as:
    "scheduler": {
        "<optimizer name>": {
            "<lr_scheduler name>": {
                "scheme": "batch"/"epoch",
                ... any remaining keys are passed as arguments to the lr_scheduler
            }
        },
        "<optimizer name>": {
            "<lr_scheduler name>": {
                "scheme": "batch"/"epoch",  # "scheme" is not a scheduler argument; it selects the event used to step it
                ... any remaining keys are passed as arguments to the lr_scheduler
            }
        },
        ...
    }
    '''
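    # For illustration only: a hypothetical "scheduler" section following the format above (the optimizer
    # name "main" and the StepLR arguments are assumptions, not taken from any shipped config):
    #
    # "scheduler": {
    #     "main": {
    #         "StepLR": {"scheme": "epoch", "step_size": 30, "gamma": 0.1}
    #     }
    # }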
    schedulers = {"batch": {}, "epoch": {}}
    if 'scheduler' in config:
        for sched_name, sched in config.scheduler.items():
            if sched_name in optimizers:
                logger.info("INFO: Setting up LR scheduler for {}".format(sched_name))
                # The LR scheduler function and its "batch"/"epoch" stepping scheme
                sched_fn, sched_scheme = get_lr_scheduler(optimizers[sched_name], sched)
                assert sched_fn, "Learning Rate scheduler for {} could not be found, please check your config".format(sched_name)
                assert sched_scheme in ["batch", "epoch"], "ERROR: Invalid scheduler scheme, must be either epoch or batch"

                schedulers[sched_scheme][sched_name] = sched_fn

        def epoch_scheduler(engine):
            for name, sched in schedulers["epoch"].items():
                sched.step()

        def batch_scheduler(engine):
            for name, sched in schedulers["batch"].items():
                sched.step()

        trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, batch_scheduler)
        trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, epoch_scheduler)

    if exp_logger is not None:
        trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, exp_logger.log_iteration, phase="train", models=models, optims=optimizers)
        trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, exp_logger.log_epoch, phase="train", models=models, optims=optimizers)
        evaluator_engine.add_event_handler(Events.ITERATION_COMPLETED, exp_logger.log_iteration, phase="evaluate", models=models, optims=optimizers)
        evaluator_engine.add_event_handler(Events.EPOCH_COMPLETED, exp_logger.log_epoch, phase="evaluate", models=models, optims=optimizers)

    # Early stopping criterion
    if "monitor" in config and config.monitor.early_stopping:
        logger.info("INFO: Enabling early stopping, monitoring {}".format(config.monitor.score))
        score_fn = lambda e: config.monitor.scale * e.state.metrics[config.monitor.score]
        # Stop training if score_function has not improved for `patience` consecutive evaluation events
        es_handler = EarlyStopping(patience=config.monitor.patience, score_function=score_fn, trainer=trainer_engine)
        evaluator_engine.add_event_handler(Events.COMPLETED, es_handler)

    # Save the best model
    if "monitor" in config and config.monitor.save_score:
        logger.info("INFO: Saving best model based on {}".format(config.monitor.save_score))
        score_fn = lambda e: config.monitor.save_scale * e.state.metrics[config.monitor.save_score]
        # Keep only the checkpoint with the best score_fn; the remaining arguments rarely need changing
        ch_handler = ModelCheckpoint(config.result_path, 'best_checkpoint', score_function=score_fn, score_name=config.monitor.save_score, n_saved=1, require_empty=False, save_as_state_dict=True)
        to_save = dict(models, **optimizers)
        evaluator_engine.add_event_handler(Events.EPOCH_COMPLETED, ch_handler, to_save)
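
    # For illustration only: a hypothetical "monitor" section driving both blocks above; the metric name
    # "val_loss" is an assumption, and scale/save_scale of -1 flip the score so that smaller is better:
    #
    # "monitor": {
    #     "early_stopping": true, "score": "val_loss", "scale": -1, "patience": 10,
    #     "save_score": "val_loss", "save_scale": -1
    # }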

    # Save checkpoints at a fixed epoch frequency
    if 'save_freq' in config and config.save_freq > 0:
        ch_handler = ModelCheckpoint(config.result_path, 'checkpoint', n_saved=config.nsave, require_empty=False, save_as_state_dict=True)
        to_save = dict(models, **optimizers)
        trainer_engine.add_event_handler(
            Events.EPOCH_COMPLETED(every=config.save_freq), 
            ch_handler, to_save
        )
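
    # For illustration only: e.g. "save_freq": 5 in the config keeps a checkpoint every 5 epochs, with
    # "nsave" (defaulting to 5 above) controlling how many of those checkpoints are retained.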

    # Register custom callbacks with the engines
    if check_if_implemented(trainer, "on_iteration_start"):
        trainer_engine.add_event_handler(Events.ITERATION_STARTED, trainer.on_iteration_start, phase="train")
        evaluator_engine.add_event_handler(Events.ITERATION_STARTED, trainer.on_iteration_start, phase="evaluate")
    if check_if_implemented(trainer, "on_iteration_end"):
        trainer_engine.add_event_handler(Events.ITERATION_COMPLETED, trainer.on_iteration_end, phase="train")
        evaluator_engine.add_event_handler(Events.ITERATION_COMPLETED, trainer.on_iteration_end, phase="evaluate")
    if check_if_implemented(trainer, "on_epoch_start"):
        trainer_engine.add_event_handler(Events.EPOCH_STARTED, trainer.on_epoch_start, phase="train")
        evaluator_engine.add_event_handler(Events.EPOCH_STARTED, trainer.on_epoch_start, phase="evaluate")
    if check_if_implemented(trainer, "on_epoch_end"):
        trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, trainer.on_epoch_end, phase="train")
        evaluator_engine.add_event_handler(Events.EPOCH_COMPLETED, trainer.on_epoch_end, phase="evaluate")

    # Save the config for this experiment to the results directory, once we know the params are good
    config.save()

    def signal_handler(sig, frame):
        print('You pressed Ctrl+C!')
        if exp_logger is not None:
            exp_logger.teardown()
        sys.exit(0)

    # Install a SIGINT handler so Ctrl+C tears down the experiment logger cleanly before exiting
    signal.signal(signal.SIGINT, signal_handler)

    logger.info("INFO: Starting training...")
    trainer_engine.run(loader_train, max_epochs=config.epochs)

    if exp_logger is not None:
        exp_logger.teardown()

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('-c', '--config', default=None, type=str, required=True, help='config file path (default: None)')
    parser.add_argument('--checkpoint', default=None, type=str, help='Checkpoint tag to reload')    
    parser.add_argument('--checkpoint_dir', default=None, type=str, help='Checkpoint directory to reload')
    parser.add_argument('--suffix', default=None, type=str, help='Suffix to append to the experiment name')
    parser.add_argument('--epochs', default=None, type=int, help='Number of epochs')
    parser.add_argument('--resume_from', default=None, type=int, help='Epoch to resume from, allows using checkpoints as initialisation')
    args = parser.parse_args()
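
    # For illustration only (the script filename and config path below are assumptions):
    #   python train.py -c configs/experiment.json
    #   python train.py -c configs/experiment.json --checkpoint best --resume_from 20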

    OVERLOADABLE = ['checkpoint', 'epochs', 'checkpoint_dir', 'resume_from']

    overloaded = {}
    for k, v in vars(args).items():
        if (k in OVERLOADABLE) and (v is not None):
            overloaded[k] = v

    config = Experiment.load_from_path(args.config, overloaded, args.suffix)

    assert config, "Config could not be loaded."

    # If resuming, load the saved config from the results dir, or throw an error if one doesn't exist
    if len(config.checkpoint) > 0:
        logger.warning("WARNING: A checkpoint was specified, resuming; overriding with the existing experiment config.")
        resume_config = Experiment(config.name, desc=config.desc, result_dir=config.result_dir).load()
        assert resume_config is not None, "No experiment {} exists, cannot resume training".format(config.name)
        config = resume_config
        assert config, "Config could not be loaded for resume"
    # If starting a fresh training run (e.g. resume_from < 0 in the config), throw an error if the results directory already exists
    elif config.overwrite is False:
        assert not config.exists(), "Results directory {} already exists! Please specify a new experiment name or remove the old files.".format(config.result_path)
    else:
        empty_folder(config.result_path)

    main(config)
